{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MetaDSL + UArray\n",
    "\n",
    "In this notebook, Hameer and I (Saul) try out some ways to have MetaDSL and UArray place nice together.\n",
    "\n",
    "First, let's installed the latest uarray"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting git+https://github.com/Quansight-Labs/uarray.git\n",
      "  Cloning https://github.com/Quansight-Labs/uarray.git to /private/var/folders/m7/t8dvwtnn32z84333p845tly40000gn/T/pip-req-build-fu55z_bi\n",
      "Building wheels for collected packages: uarray\n",
      "  Building wheel for uarray (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Stored in directory: /private/var/folders/m7/t8dvwtnn32z84333p845tly40000gn/T/pip-ephem-wheel-cache-j7a800tz/wheels/3d/1a/bf/60f787ed8f0ac071de28d869eb644793856923ac4688c14544\n",
      "Successfully built uarray\n",
      "Installing collected packages: uarray\n",
      "  Found existing installation: uarray 0.4+167.g8bee995\n",
      "    Uninstalling uarray-0.4+167.g8bee995:\n",
      "      Successfully uninstalled uarray-0.4+167.g8bee995\n",
      "Successfully installed uarray-0.4+168.g345664c\n"
     ]
    }
   ],
   "source": [
    "!pip install -U git+https://github.com/Quansight-Labs/uarray.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, there are two ways of integrating:\n",
    "\n",
    "1. `metadsl -> uarray` Build metadsl graph of numpy expression using `metadsl.nonumpy.compat` and then execute them using a uarray backend\n",
    "2. `uarray -> metadsl` Build up a metadsl graph using `unumpy`, by defining a `metadsl` backend for `uarray`.\n",
    "\n",
    "Both can make sense, so let's do a simple implementation of both sides and see how we can use them togther. We will start with (2), adding a `metadsl` backend for `uarray`: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import uarray\n",
    "import unumpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import metadsl.expressions\n",
    "import metadsl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uarray.backend import TypeCheckBackend, register_backend\n",
    "\n",
    "MetaDSLBackend = TypeCheckBackend(\n",
    "    tuple(),\n",
    "    convertor=lambda x: x,\n",
    "    fallback_types=tuple()\n",
    ")\n",
    "register_backend(MetaDSLBackend)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's just register one call, the `arange` call, and dispatch this by creating a call graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "arange_call = metadsl.call(lambda *args: metadsl.instance_type(metadsl.Instance))(unumpy.arange)\n",
    "\n",
    "@uarray.multimethod(MetaDSLBackend, unumpy.arange)\n",
    "def aranged_call_wrapped(start, stop, stride):\n",
    "    return arange_call(metadsl.Instance(start), metadsl.Instance(stop), metadsl.Instance(stride))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can use the metadsl backend to create a graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instance(__value__=Call(function=<function arange at 0x103b1c510>, args=(Instance(__value__=0), Instance(__value__=10), Instance(__value__=2))))\n"
     ]
    }
   ],
   "source": [
    "with uarray.set_backend(MetaDSLBackend):\n",
    "    print(unumpy.arange(0, 10, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And here is a simple versio of (1) which takes a metadsl graph of unumpy expressions and executes them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import unumpy.numpy_backend\n",
    "\n",
    "def execute_graph(x):\n",
    "    \"\"\"\n",
    "    Test version of executing a graph, in the form of a single funtion with primitive args.\n",
    "    \"\"\"\n",
    "    val = x.__value__\n",
    "    fn = val.function\n",
    "    args = val.args\n",
    "    return fn(*(a.__value__ for a in args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 3 5 7 9]\n"
     ]
    }
   ],
   "source": [
    "print(execute_graph(aranged_call_wrapped(1, 10, 2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can use them together!\n",
    "\n",
    "First we make a graph, using unumpy expression, then we execute that graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instance(__value__=Call(function=<function arange at 0x106e91730>, args=(Instance(__value__=0), Instance(__value__=10), Instance(__value__=2))))\n",
      "[0 2 4 6 8]\n"
     ]
    }
   ],
   "source": [
    "with uarray.set_backend(MetaDSLBackend, coerce=True):\n",
    "\n",
    "    x = unumpy.arange(0, 10, 2)\n",
    "    print(x)\n",
    "    with uarray.skip_backend(MetaDSLBackend):\n",
    "        print(execute_graph(x))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
